#!/bin/bash

# NVIDIA 驱动安装助手脚本
# 用于帮助选择并安装合适的 NVIDIA 驱动版本

# Banner, then look for an NVIDIA device in the lspci listing.
# If none is found, ask whether to continue; anything but y/Y exits with status 1.
printf '%s\n' "=========================================" \
    "NVIDIA 驱动安装助手" \
    "========================================="

echo "检查系统 GPU 信息..."
if ! lspci | grep -qi nvidia 2>/dev/null; then
    printf '%s\n' "⚠ 未检测到 NVIDIA GPU" \
        "请确认：" \
        "  1. 系统确实有 NVIDIA GPU" \
        "  2. GPU 已正确连接到主板"
    read -r -n 1 -p "是否继续安装驱动？(y/N): "
    echo
    [[ $REPLY =~ ^[Yy]$ ]] || exit 1
else
    echo "✓ 检测到 NVIDIA GPU"
    # Show the matching device lines to the user.
    lspci | grep -i nvidia
fi

# Static guide describing the offered driver packages (printed verbatim).
cat <<'EOF'

=========================================
驱动版本选择指南
=========================================

根据您的 PyTorch CUDA 版本（12.8），推荐以下驱动版本：

【推荐选项】
  1. nvidia-utils-535-server (推荐，稳定)
     - 支持 CUDA 12.x
     - 适合服务器环境
     - 版本: 535.274.02

  2. nvidia-utils-550-server (较新，推荐)
     - 支持 CUDA 12.x
     - 适合服务器环境
     - 版本: 550.163.01

  3. nvidia-utils-535 (桌面环境)
     - 支持 CUDA 12.x
     - 适合桌面环境

【其他选项】
  - nvidia-utils-545: 较新版本
  - nvidia-utils-565-server: 最新服务器版本
  - nvidia-utils-570-server: 最新服务器版本
  - nvidia-utils-580-server: 最新服务器版本

注意:
  - 服务器环境推荐使用 -server 版本
  - 桌面环境可以使用普通版本
  - 所有列出的版本都支持 CUDA 12.x

EOF

# Heuristic environment check: a session with a DISPLAY that is not a plain
# TTY counts as a desktop; everything else counts as a server.
# Sets IS_SERVER to "true" or "false" and prints a matching hint.
if [[ -n $DISPLAY && $XDG_SESSION_TYPE != "tty" ]]; then
    IS_SERVER=false
    echo "检测到桌面环境，可以使用普通版本或 -server 版本"
else
    IS_SERVER=true
    echo "检测到服务器环境，推荐使用 -server 版本"
fi

# Ask the user which driver to install.
# Sets:
#   DRIVER_PKG     - the nvidia-utils-* package to install
#                    (e.g. nvidia-utils-535-server)
#   DRIVER_VERSION - the part after "nvidia-utils-" (e.g. 535-server). It is
#                    later combined into nvidia-driver-$DRIVER_VERSION.
# Invalid menu choices and invalid custom package names fall back to
# nvidia-utils-535-server.
echo ""
read -r -p "请选择驱动版本 (1=535-server, 2=550-server, 3=535, 4=自定义): " choice

case "$choice" in
    1) DRIVER_PKG="nvidia-utils-535-server" ;;
    2) DRIVER_PKG="nvidia-utils-550-server" ;;
    3) DRIVER_PKG="nvidia-utils-535" ;;
    4)
        echo ""
        echo "可用的驱动版本："
        apt search nvidia-utils 2>/dev/null | grep "^nvidia-utils" | head -20
        echo ""
        read -r -p "请输入完整的包名（如 nvidia-utils-535-server）: " DRIVER_PKG
        # The name is passed to `sudo apt install`, so accept only a single
        # well-formed nvidia-utils-<number>[-flavour] package name.
        pkg_re='^nvidia-utils-[0-9]+(-[a-z]+)*$'
        if [[ ! $DRIVER_PKG =~ $pkg_re ]]; then
            echo "无效的包名，使用默认: nvidia-utils-535-server"
            DRIVER_PKG="nvidia-utils-535-server"
        fi
        ;;
    *)
        echo "无效选择，使用默认: nvidia-utils-535-server"
        DRIVER_PKG="nvidia-utils-535-server"
        ;;
esac

# Derive the driver flavour from the utils package for every branch, including
# the custom one. This makes nvidia-utils-535-server pair with
# nvidia-driver-535-server rather than the desktop nvidia-driver-535.
# NOTE(review): Ubuntu ships -server and non-server as separate matched sets,
# and mixing them makes apt swap the utils package; confirm this on the
# target release.
DRIVER_VERSION=${DRIVER_PKG#nvidia-utils-}

# Summarize the planned steps for the chosen package and ask for confirmation.
# Anything other than y/Y cancels the script with exit status 0.
cat <<EOF

=========================================
准备安装: $DRIVER_PKG
=========================================

安装步骤：
  1. 更新包列表
  2. 安装 $DRIVER_PKG
  3. 安装对应的驱动包: nvidia-driver-$DRIVER_VERSION
  4. 重启系统

EOF
read -r -n 1 -p "是否继续？(y/N): "
echo

case $REPLY in
    [Yy]) ;;
    *)
        echo "已取消安装"
        exit 0
        ;;
esac

# Run the package-manager steps for the selected driver.
# Reads DRIVER_PKG and DRIVER_VERSION. Exits with status 1 if either is empty
# or if an install step fails, so the script never reports success or offers
# a reboot after a failed installation.
echo ""
echo "开始安装..."

if [[ -z $DRIVER_PKG || -z $DRIVER_VERSION ]]; then
    echo "错误: 未确定要安装的驱动包 (DRIVER_PKG/DRIVER_VERSION 为空)" >&2
    exit 1
fi

# Step 1: refresh package lists. A failure here (for example one unreachable
# repository) is reported but not fatal. The install steps below fail on
# their own if the packages cannot be resolved.
echo "步骤 1/4: 更新包列表..."
if ! sudo apt update; then
    echo "⚠ 更新包列表失败，继续尝试安装..." >&2
fi

# Step 2: the selected nvidia-utils package (quoted so it stays one argument).
echo "步骤 2/4: 安装 $DRIVER_PKG..."
if ! sudo apt install -y "$DRIVER_PKG"; then
    echo "错误: 安装 $DRIVER_PKG 失败" >&2
    exit 1
fi

# Step 3: the matching driver package.
echo "步骤 3/4: 安装 nvidia-driver-$DRIVER_VERSION..."
if ! sudo apt install -y "nvidia-driver-$DRIVER_VERSION"; then
    echo "错误: 安装 nvidia-driver-$DRIVER_VERSION 失败" >&2
    exit 1
fi

# Installation finished: remind the user that a reboot is required and offer
# to reboot now. Only an explicit y/Y triggers `sudo reboot`.
cat <<'EOF'

=========================================
安装完成！
=========================================

⚠️  重要: 需要重启系统才能使驱动生效

重启后，运行以下命令验证：
  nvidia-smi

EOF
read -r -n 1 -p "是否现在重启系统？(y/N): "
echo

case $REPLY in
    [Yy])
        echo "正在重启系统..."
        sudo reboot
        ;;
    *)
        echo "请稍后手动重启系统: sudo reboot"
        ;;
esac

